# -*- coding: utf-8 -*-
import argparse
import os

import numpy as np
from bert4keras.backend import keras, K
from bert4keras.layers import Loss, Embedding
from bert4keras.models import build_transformer_model, BERT
from bert4keras.optimizers import Adam
from bert4keras.snippets import sequence_padding, DataGenerator
from bert4keras.tokenizers import Tokenizer
from config import train_frac, dict_path, config_path, checkpoint_path, batch_size, maxlen
from flyai.data_helper import DataHelper
from flyai.framework import FlyAI
from flyai.train_helper import download
from flyai.train_helper import submit

# FlyAI logging helper: calling train_log lets the platform plot real-time
# curves of training/validation accuracy and loss online.
from flyai.utils.log_helper import train_log

from data_loader import medical_data

from path import MODEL_PATH

'''
This project uses the new FlyAI 2.0 framework; data loading and evaluation differ from previous versions.
The 2.0 framework no longer restricts how data is read.
The sample code is for reference only; you may modify the implementation logic yourself.
Template projects are available for PyTorch, TensorFlow, Keras, MXNet, scikit-learn and other ML frameworks.
First-time users should read the file in this project: FlyAI2.0竞赛框架使用说明.html
Pre-trained models provided by FlyAI: https://www.flyai.com/models
Documentation center: https://doc.flyai.com/
FAQ: https://doc.flyai.com/question.html
If you run into problems, add the assistant on WeChat by scanning the QR code in the project:
FlyAI小助手二维码-小姐姐在线解答您的问题.png
'''
if not os.path.exists(MODEL_PATH):
    os.makedirs(MODEL_PATH)

# Project hyperparameters; delete them if you do not need them
parser = argparse.ArgumentParser()
parser.add_argument("-e", "--EPOCHS", default=10, type=int, help="train epochs")
parser.add_argument("-b", "--BATCH", default=32, type=int, help="batch size")
args = parser.parse_args()
# Download and unpack the pre-trained language model
download('/data/chinese_roberta_L-6_H-384_A-12_K-128.zip', decompression=True)
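# Note (assumption): config_path, checkpoint_path and dict_path imported from
# config are taken to point inside the unpacked
# chinese_roberta_L-6_H-384_A-12_K-128 directory downloaded above.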


class Main(FlyAI):
    """
    项目中必须继承FlyAI类，否则线上运行会报错。
    """

    def download_data(self):
        # Download the training data by dataset ID
        data_helper = DataHelper()
        data_helper.download_from_ids("PsychologicalQA")

    def deal_with_data(self):
        """
        处理数据，没有可不写。
        :return:
        """
        pass

    def train(self):
        """
        训练模型，必须实现此方法
        :return:
        """

        train_data, test_data, valid_data = medical_data()
        # Labeled portion of the training data
        num_labeled = int(len(train_data) * train_frac)
        # The remaining samples are treated as unlabeled (label 2)
        unlabeled_data = [(t, 2) for t, _ in train_data[num_labeled:]]
        train_data = train_data[:num_labeled]
        # Optionally mix the unlabeled data back in for semi-supervised MLM training:
        # train_data = train_data + unlabeled_data

        # Build the tokenizer
        tokenizer = Tokenizer(dict_path, do_lower_case=True)

        # Task pattern: eight reserved [unused1]..[unused8] tokens with a
        # [MASK] inserted in the middle, so the mask sits at position 5 once
        # [CLS] is prepended.
        mask_idx = 5
        desc = ['[unused%s]' % i for i in range(1, 9)]
        desc.insert(mask_idx - 1, '[MASK]')
        desc_ids = [tokenizer.token_to_id(t) for t in desc]
        pos_id = tokenizer.token_to_id(u'很')
        neg_id = tokenizer.token_to_id(u'不')
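
        # P-tuning: the [unused1]..[unused8] embeddings are the only trainable
        # parameters, and the verbalizer maps the [MASK] prediction to a label
        # ('不' for label 0, '很' for label 1). A labeled input thus looks like:
        #   [CLS] u1 u2 u3 u4 [MASK] u5 u6 u7 u8 <text tokens> [SEP]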

        def random_masking(token_ids):
            """Apply BERT-style random masking to the input.

            Each token is selected with probability 0.15; of the selected
            tokens, 80% are replaced with [MASK], 10% are kept unchanged and
            10% are replaced with a random token. Unselected positions get
            target 0, which the loss later ignores.
            """
            rands = np.random.random(len(token_ids))
            source, target = [], []
            for r, t in zip(rands, token_ids):
                if r < 0.15 * 0.8:
                    # replace with [MASK] and predict the original token
                    source.append(tokenizer._token_mask_id)
                    target.append(t)
                elif r < 0.15 * 0.9:
                    # keep the token but still predict it
                    source.append(t)
                    target.append(t)
                elif r < 0.15:
                    # replace with a random token id (anything but 0/[PAD])
                    source.append(np.random.choice(tokenizer._vocab_size - 1) + 1)
                    target.append(t)
                else:
                    # not selected: target 0 is masked out of the loss
                    source.append(t)
                    target.append(0)
            return source, target
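
        # For example (token ids are illustrative), token_ids
        # [101, 2769, 4263, 102] might come back as source
        # [101, 103, 4263, 102] and target [0, 2769, 0, 0]:
        # only the masked position contributes to the loss.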

        class data_generator(DataGenerator):
            """Data generator.

            Labels 0/1 mark labeled samples: the pattern is inserted and the
            mask position is forced to the corresponding verbalizer token.
            Label 2 marks unlabeled samples, which skip the pattern and are
            used for plain MLM training.
            """

            def __iter__(self, random=False):
                batch_token_ids, batch_segment_ids, batch_output_ids = [], [], []
                for is_end, (text, label) in self.sample(random):
                    token_ids, segment_ids = tokenizer.encode(text, maxlen=maxlen)
                    if label != 2:
                        # insert the pattern right after [CLS]
                        token_ids = token_ids[:1] + desc_ids + token_ids[1:]
                        segment_ids = [0] * len(desc_ids) + segment_ids
                    if random:
                        source_ids, target_ids = random_masking(token_ids)
                    else:
                        source_ids, target_ids = token_ids[:], token_ids[:]
                    if label == 0:
                        # label 0: predict '不' at the mask position
                        source_ids[mask_idx] = tokenizer._token_mask_id
                        target_ids[mask_idx] = neg_id
                    elif label == 1:
                        # label 1: predict '很' at the mask position
                        source_ids[mask_idx] = tokenizer._token_mask_id
                        target_ids[mask_idx] = pos_id
                    batch_token_ids.append(source_ids)
                    batch_segment_ids.append(segment_ids)
                    batch_output_ids.append(target_ids)
                    if len(batch_token_ids) == self.batch_size or is_end:
                        batch_token_ids = sequence_padding(batch_token_ids)
                        batch_segment_ids = sequence_padding(batch_segment_ids)
                        batch_output_ids = sequence_padding(batch_output_ids)
                        yield [
                                  batch_token_ids, batch_segment_ids, batch_output_ids
                              ], None
                        batch_token_ids, batch_segment_ids, batch_output_ids = [], [], []

        class CrossEntropy(Loss):
            """Cross entropy as the loss, masking out the input part:
            positions whose target id is 0 are excluded from both the loss
            and the accuracy metric.
            """

            def compute_loss(self, inputs, mask=None):
                y_true, y_pred = inputs
                y_mask = K.cast(K.not_equal(y_true, 0), K.floatx())
                accuracy = keras.metrics.sparse_categorical_accuracy(y_true, y_pred)
                accuracy = K.sum(accuracy * y_mask) / K.sum(y_mask)
                self.add_metric(accuracy, name='accuracy')
                loss = K.sparse_categorical_crossentropy(y_true, y_pred)
                loss = K.sum(loss * y_mask) / K.sum(y_mask)
                return loss

        class PtuningEmbedding(Embedding):
            """Custom Embedding layer that optimizes only part of the tokens.
            """

            def call(self, inputs, mode='embedding'):
                embeddings = self.embeddings
                embeddings_sg = K.stop_gradient(embeddings)
                mask = np.zeros((K.int_shape(embeddings)[0], 1))
                mask[1:9] += 1  # only token ids 1~8 (the [unused*] pattern tokens)
                self.embeddings = embeddings * mask + embeddings_sg * (1 - mask)
                outputs = super(PtuningEmbedding, self).call(inputs, mode)
                self.embeddings = embeddings
                return outputs
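
        # The forward values are unchanged (mask + (1 - mask) sums to 1), but
        # stop_gradient blocks backprop through every embedding row except
        # 1~8, so only the pattern-token embeddings are actually updated.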

        class PtuningBERT(BERT):
            """BERT variant that replaces the original Embedding with PtuningEmbedding.
            """

            def apply(self, inputs=None, layer=None, arguments=None, **kwargs):
                if layer is Embedding:
                    layer = PtuningEmbedding
                return super(PtuningBERT,
                             self).apply(inputs, layer, arguments, **kwargs)

        # Load the pre-trained model with the MLM head (the output is a
        # per-position distribution over the vocabulary)
        model = build_transformer_model(
            config_path=config_path,
            checkpoint_path=checkpoint_path,
            model=PtuningBERT,
            with_mlm=True
        )

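        # Freeze everything except the token embeddings; PtuningEmbedding then
        # restricts which embedding rows receive gradients.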
        for layer in model.layers:
            if layer.name != 'Embedding-Token':
                layer.trainable = False

        # Training model
        y_in = keras.layers.Input(shape=(None,))
        # Keep only the first 10 output positions ([CLS] plus the 9 pattern
        # tokens), which cover the mask at index 5, for the prediction model
        output = keras.layers.Lambda(lambda x: x[:, :10])(model.output)
        outputs = CrossEntropy(1)([y_in, model.output])

        train_model = keras.models.Model(model.inputs + [y_in], outputs)
        train_model.compile(optimizer=Adam(6e-4))
        train_model.summary()
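
        # No loss is passed to compile(): CrossEntropy is a bert4keras Loss
        # layer and registers the loss on the model itself.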

        # Prediction model
        model = keras.models.Model(model.inputs, output)

        # Wrap the datasets in generators
        train_generator = data_generator(train_data, batch_size)
        valid_generator = data_generator(valid_data, batch_size)
        test_generator = data_generator(test_data, batch_size)

        class Evaluator(keras.callbacks.Callback):
            """Evaluate after every epoch and keep the weights with the best validation accuracy."""

            def __init__(self):
                self.best_val_acc = 0.

            def on_epoch_end(self, epoch, logs=None):
                val_acc = evaluate(valid_generator)
                if val_acc > self.best_val_acc:
                    self.best_val_acc = val_acc
                    model.save_weights(os.path.join(MODEL_PATH, 'best_model_bert.weights'))
                test_acc = evaluate(test_generator)
                print(
                    u'val_acc: %.5f, best_val_acc: %.5f, test_acc: %.5f\n' %
                    (val_acc, self.best_val_acc, test_acc)
                )
                train_log(val_acc=val_acc)

        def evaluate(data):
            total, right = 0., 0.
            for x_true, _ in data:
                x_true, y_true = x_true[:2], x_true[2]
                y_pred = model.predict(x_true)
                # compare the scores of '不' and '很' at the mask position
                y_pred = y_pred[:, mask_idx, [neg_id, pos_id]].argmax(axis=1)
                y_true = (y_true[:, mask_idx] == pos_id).astype(int)
                total += len(y_true)
                right += (y_true == y_pred).sum()
            return right / total

        evaluator = Evaluator()
        tb_cb = keras.callbacks.TensorBoard(log_dir="log_dir", write_images=True, histogram_freq=0)

        train_model.fit_generator(
            train_generator.forfit(),
            steps_per_epoch=len(train_generator) * 50,  # one "epoch" here = 50 passes over the labeled data
            epochs=1000,
            callbacks=[evaluator, tb_cb]
        )


if __name__ == '__main__':
    main = Main()
    main.download_data()
    main.train()

    submit("p-tuning_medical", cmd="python medical_main.py")
